DAGMM (Deep Autoencoding Gaussian Mixture Model) for the arrhythmia data set

Author

kione kim

Published

October 19, 2023

Deep Autoencoding Gaussian Mixture Model for Arrhythmia dataset

### imports
import torch
from torch import nn
import numpy as np
import pandas as pd
import argparse
import sys
### data file
# Load the UCI arrhythmia data set. Missing values are encoded as '?' in the
# raw file; they are replaced with 0 before casting every column to float.
file_path = 'C:\\Users\\UOS\\Desktop\\연구\\5. 데이터\\data\\arrhythmia\\arrhythmia.data'

df = pd.read_csv(file_path, header=None)
df = df.replace('?', 0)
df = df.astype('float64')

# (452, 280) float32 tensor. torch.autograd.Variable is deprecated and was a
# no-op wrapper here; a plain tensor behaves identically.
data_array = torch.from_numpy(df.values).float()
parser = argparse.ArgumentParser(description='parser for argparse test')

parser.add_argument('--input_dim', type=int, default=data_array.shape[-1])
parser.add_argument('--enc_hidden_dim', type=str, default='10,2')
parser.add_argument('--dec_hidden_dim', type=str, default='10')
parser.add_argument('--est_hidden_dim', type=str, default='4, 10, 2')
# Bug fix: the original used action='store_true' with a float default, so
# passing --dropout on the command line set it to True instead of a rate.
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--num_epoch', type=int, default=10)

# When run under a Jupyter kernel, strip the kernel's extra argv entries so
# parse_args() does not fail on them.
if 'ipykernel_launcher' in sys.argv[0]:
    sys.argv = [sys.argv[0]]

args = parser.parse_args()

# Expand the comma-separated dimension strings into full layer-size lists:
# encoder runs input_dim -> hidden dims; the decoder mirrors back to
# input_dim; the estimation net sizes are used as given.
args.enc_hidden_dim_list = [args.input_dim] + [int(d) for d in args.enc_hidden_dim.split(',')]
args.dec_hidden_dim_list = ([args.enc_hidden_dim_list[-1]]
                            + [int(d) for d in args.dec_hidden_dim.split(',')]
                            + [args.input_dim])
args.est_hidden_dim_list = [int(d) for d in args.est_hidden_dim.split(',')]
### compression network
class midlayer(nn.Module):
    """A single fully-connected layer followed by a Tanh activation."""

    def __init__(self, input_dim, hidden_dim):
        super(midlayer, self).__init__()
        self.fc_layer = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.Tanh()

    def forward(self, input):
        # Linear projection, then squash into (-1, 1).
        return self.activation(self.fc_layer(input))


class Encoder(nn.Module):
    """Compression-network encoder.

    Stacks Linear+Tanh blocks sized by `hidden_dim_list`, ending in a plain
    Linear layer (no activation) that produces the latent code.
    hidden_dim_list[0] is the input dimension, hidden_dim_list[-1] the
    latent dimension.
    """

    def __init__(self, hidden_dim_list):
        super(Encoder, self).__init__()

        layer_list = []
        for i in range(len(hidden_dim_list) - 2):
            layer_list.append(midlayer(hidden_dim_list[i], hidden_dim_list[i + 1]))

        # Bug fix: the original indexed with the loop variable `i` after the
        # loop, which raises NameError when hidden_dim_list has only two
        # entries (the loop body never runs). Negative indices always work.
        layer_list.append(nn.Linear(hidden_dim_list[-2], hidden_dim_list[-1]))
        self.layer = nn.Sequential(*layer_list)

    def forward(self, input):
        return self.layer(input)
    
class Decoder(nn.Module):
    """Compression-network decoder.

    Mirrors the encoder with Linear+Tanh blocks, reconstructing the input
    from the latent code. hidden_dim_list[0] is the latent dimension,
    hidden_dim_list[-1] the reconstruction (input) dimension.

    NOTE(review): the final block includes Tanh, so reconstructions lie in
    (-1, 1) — confirm the inputs are scaled accordingly.
    """

    def __init__(self, hidden_dim_list):
        super(Decoder, self).__init__()

        layer_list = []
        for i in range(len(hidden_dim_list) - 2):
            layer_list.append(midlayer(hidden_dim_list[i], hidden_dim_list[i + 1]))

        # Bug fix: the original indexed with the loop variable `i` after the
        # loop, raising NameError for a two-entry hidden_dim_list.
        layer_list.append(midlayer(hidden_dim_list[-2], hidden_dim_list[-1]))
        self.layer = nn.Sequential(*layer_list)

    def forward(self, input):
        return self.layer(input)

class CompressionNet(nn.Module):
    """Autoencoder used as DAGMM's compression network."""

    def __init__(self, enc_hidden_dim_list, dec_hidden_dim_list):
        super().__init__()
        self.encoder = Encoder(enc_hidden_dim_list)
        self.decoder = Decoder(dec_hidden_dim_list)

        self._reconstruction_loss = nn.MSELoss()

    def forward(self, input):
        # Full round trip: encode to the latent space, then reconstruct.
        return self.decoder(self.encoder(input))

    def encode(self, input):
        return self.encoder(input)

    def decode(self, input):
        return self.decoder(input)

    def reconstruction_loss(self, input, input_target):
        """MSE between the reconstruction of `input` and `input_target`."""
        target_hat = self(input)
        return self._reconstruction_loss(target_hat, input_target)

    # Backward-compatible alias for the original misspelled method name.
    reconstuction_loss = reconstruction_loss
### reconstruction error features
# Small constant guarding against division by zero in the norms below.
# (torch.autograd.Variable is deprecated; a plain tensor is equivalent.)
eps = torch.tensor([1.e-8], requires_grad=False)

def relative_euclidean_distance(x1, x2, eps=eps):
    """Row-wise ||x1 - x2||_2 / ||x1||_2, denominator clamped away from zero."""
    num = torch.norm(x1 - x2, p=2, dim=1)
    denom = torch.norm(x1, p=2, dim=1)
    return num / torch.max(denom, eps)

def cosine_similarity(x1, x2, eps=eps):
    """Row-wise cosine similarity between x1 and x2."""
    dot_prod = torch.sum(x1 * x2, dim=1)
    dist_x1 = torch.norm(x1, p=2, dim=1)
    dist_x2 = torch.norm(x2, p=2, dim=1)
    return dot_prod / torch.max(dist_x1 * dist_x2, eps)
### estimation network
class Estimation(nn.Module):
    """DAGMM estimation network.

    Maps latent vectors to soft mixture memberships; softmax output rows
    sum to 1.
    """

    def __init__(self, est_hidden_dim_list):
        super().__init__()

        layer_list = []
        for i in range(len(est_hidden_dim_list) - 2):
            layer_list.append(midlayer(est_hidden_dim_list[i], est_hidden_dim_list[i + 1]))

        layer_list.append(nn.Dropout(p=0.5))
        layer_list.append(nn.Linear(est_hidden_dim_list[-2], est_hidden_dim_list[-1]))
        # Fix: specify the softmax dimension explicitly. The implicit choice
        # is deprecated (it emitted a UserWarning at runtime) and dim=1 is
        # what the implicit rule selects for 2-D (batch, features) input.
        layer_list.append(nn.Softmax(dim=1))
        self.net = nn.Sequential(*layer_list)

    def forward(self, input):
        return self.net(input)
### Mixture
class Mixture(nn.Module):
    """One Gaussian component of the GMM energy model.

    Phi (mixture weight), mu (mean) and Sigma (covariance) are stored as
    non-trainable Parameters: they are not learned by backprop but are
    overwritten in place by the EM-style update in `_update_parameters`.
    """

    def __init__(self, latent_dimension):
        super().__init__()
        self.latent_dimension = latent_dimension

        # Mixture weight, randomly initialised in [0, 1).
        self.Phi    = np.random.random([1])
        self.Phi    = torch.from_numpy(self.Phi).float()
        self.Phi    = nn.Parameter(self.Phi, requires_grad = False)

        # Mean vector, randomly initialised in [-0.5, 1.5).
        self.mu     = 2.*np.random.random([latent_dimension]) - 0.5
        self.mu     = torch.from_numpy(self.mu).float()
        self.mu     = nn.Parameter(self.mu, requires_grad = False)

        # Covariance matrix, initialised to the identity.
        self.Sigma  = np.eye(latent_dimension, latent_dimension)
        self.Sigma  = torch.from_numpy(self.Sigma).float()
        self.Sigma  = nn.Parameter(self.Sigma, requires_grad = False)
        
        # Small diagonal jitter added after each covariance update to keep
        # Sigma invertible.
        self.eps_Sigma  = torch.FloatTensor(np.diag([1.e-8 for _ in range(latent_dimension)]))

    def forward(self, est_inputs, with_log = True):
        """Per-sample weighted Gaussian density (negative log if `with_log`).

        NOTE(review): each density is converted to a Python float and the
        results are repacked into a fresh tensor, so no gradient flows
        through this forward pass — training relies on the EM updates.
        """
        batch_size, _   = est_inputs.shape
        out_values  = []
        inv_sigma   = torch.inverse(self.Sigma)
        # Determinant computed via NumPy, then wrapped back into a tensor.
        det_sigma   = np.linalg.det(self.Sigma.data.cpu().numpy())
        det_sigma   = torch.from_numpy(det_sigma.reshape([1])).float()
        det_sigma   = torch.autograd.Variable(det_sigma)
        for est_input in est_inputs:
            # Quadratic form -(1/2) (z - mu)^T Sigma^{-1} (z - mu).
            diff    = (est_input - self.mu).view(-1,1)
            out     = -0.5 * torch.mm(torch.mm(diff.view(1,-1), inv_sigma), diff)
            # Phi * N(z; mu, Sigma). NOTE(review): the normaliser is
            # sqrt(2*pi*det) rather than sqrt((2*pi)^d * det) — confirm
            # against the DAGMM paper's energy definition.
            out     = (self.Phi * torch.exp(out)) / torch.sqrt(2. * np.pi * det_sigma)
            if with_log:
                out = -torch.log(out)
            out_values.append(float(out.data.cpu().numpy()))

        out = torch.autograd.Variable(torch.FloatTensor(out_values))
        return out
    
    def _update_parameters(self, samples, affiliations):
        """EM-style refresh of Phi, mu and Sigma from a batch of latent
        samples and their soft memberships for this component.

        No-op in eval mode (`self.training` is False).
        """
        if not self.training:
            return

        batch_size, _ = samples.shape

        # Updating phi: mean membership over the batch.
        phi = torch.mean(affiliations)
        self.Phi.data = phi.data

        # Updating mu: membership-weighted mean of the samples.
        num = 0.
        for i in range(batch_size):
            z_i     = samples[i, :]
            gamma_i = affiliations[i]
            num     += gamma_i * z_i
        
        denom        = torch.sum(affiliations)
        self.mu.data = (num / denom).data

        # Updating Sigma: membership-weighted outer products around the
        # freshly updated mean, plus a diagonal jitter for invertibility.
        mu  = self.mu
        num = None
        for i in range(batch_size):
            z_i     = samples[i, :]
            gamma_i = affiliations[i]
            diff    = (z_i - mu).view(-1, 1)
            to_add  = gamma_i * torch.mm(diff, diff.view(1, -1))
            if num is None:
                num = to_add
            else:
                num += to_add

        denom           = torch.sum(affiliations)
        self.Sigma.data = (num / denom).data + self.eps_Sigma


class GMM(nn.Module):
    """Mixture of `num_mixtures` Gaussian components.

    The forward pass returns the per-sample energy
    -log(sum_k Phi_k * N(z; mu_k, Sigma_k)).
    """

    def __init__(self, num_mixtures, latent_dimension):
        super().__init__()
        self.num_mixtures = num_mixtures
        self.latent_dimension = latent_dimension
        self.mixtures = nn.ModuleList(
            [Mixture(latent_dimension) for _ in range(num_mixtures)]
        )

    def forward(self, est_inputs):
        # Accumulate the raw densities of every component, then take -log.
        total = None
        for component in self.mixtures:
            density = component(est_inputs, with_log=False)
            total = density if total is None else total + density
        return -torch.log(total)

    def _update_mixtures_parameters(self, samples, mixtures_affiliations):
        # Each component sees its own column of soft memberships.
        if not self.training:
            return

        for k, component in enumerate(self.mixtures):
            component._update_parameters(samples, mixtures_affiliations[:, k])
### model
class DAGMM(nn.Module):
    """Deep Autoencoding Gaussian Mixture Model.

    Glues together the compression network (autoencoder), the estimation
    network (soft mixture memberships) and the GMM energy model.
    """

    def __init__(self, compression_module, estimation_module, gmm_module):
        super().__init__()

        self.compressor = compression_module
        self.estimator  = estimation_module
        self.gmm        = gmm_module

    def forward(self, input):
        """Return the per-sample GMM energy of the latent representation."""
        encoded = self.compressor.encode(input)
        decoded = self.compressor.decode(encoded)

        # Reconstruction-error features appended to the latent code.
        relative_ed = relative_euclidean_distance(input, decoded)
        cosine_sim  = cosine_similarity(input, decoded)

        relative_ed = relative_ed.view(-1, 1)
        # Bug fix: the original assigned `relative_ed.view(-1, 1)` here,
        # silently duplicating the euclidean feature and discarding the
        # cosine similarity.
        cosine_sim  = cosine_sim.view(-1, 1)
        latent_vectors = torch.cat([encoded, relative_ed, cosine_sim], dim=1)

        if self.training:
            # EM-style parameter refresh using the current soft memberships.
            mixtures_affiliations = self.estimator(latent_vectors)
            self.gmm._update_mixtures_parameters(latent_vectors,
                                                 mixtures_affiliations)
        return self.gmm(latent_vectors)


class DAGMMArrhythmia(DAGMM):
    """DAGMM wired with the layer sizes used for the arrhythmia data set."""

    def __init__(self, enc_hidden_dim_list, dec_hidden_dim_list, est_hidden_dim_list):
        super().__init__(
            compression_module=CompressionNet(enc_hidden_dim_list, dec_hidden_dim_list),
            estimation_module=Estimation(est_hidden_dim_list),
            gmm_module=GMM(num_mixtures=2, latent_dimension=4),
        )
### tests
def test_dagmm():
    """Smoke test: run the full DAGMM forward pass on the arrhythmia data."""
    model = DAGMMArrhythmia(args.enc_hidden_dim_list,
                            args.dec_hidden_dim_list,
                            args.est_hidden_dim_list)
    print(model(data_array))

def convert_to_var(input):
    """Convert a NumPy array to a float32 tensor.

    torch.autograd.Variable is deprecated and was a no-op wrapper here,
    so the tensor is returned directly; callers see the same object type.
    """
    return torch.from_numpy(input).float()

def test_update_mixture():
    """Print a Mixture's parameters before and after one EM update."""
    batch_size = 5
    latent_dimension = 7

    mix = Mixture(latent_dimension)
    latent_vectors = convert_to_var(np.random.random([batch_size, latent_dimension]))
    affiliations = convert_to_var(np.random.random([batch_size]))

    for param in mix.parameters():
        print(param)

    mix.train()
    mix._update_parameters(latent_vectors, affiliations)

    for param in mix.parameters():
        print(param)


def test_forward_mixture():
    """Print the densities of random latent vectors under one Mixture."""
    batch_size = 5
    latent_dimension = 7

    mix = Mixture(latent_dimension)
    vectors = convert_to_var(np.random.random([batch_size, latent_dimension]))

    mix.train()
    print(mix(vectors))


def test_update_gmm():
    """Print a GMM's parameters before and after one EM update."""
    batch_size = 5
    latent_dimension = 7
    num_mixtures = 2

    gmm = GMM(num_mixtures, latent_dimension)
    vectors = convert_to_var(np.random.random([batch_size, latent_dimension]))
    memberships = convert_to_var(np.random.random([batch_size, num_mixtures]))

    for param in gmm.parameters():
        print(param)

    gmm.train()
    gmm._update_mixtures_parameters(vectors, memberships)

    for param in gmm.parameters():
        print(param)
# Run every smoke test when this file is executed as a script.
if __name__ == '__main__':
    test_update_mixture()
    test_forward_mixture()
    test_update_gmm()
    test_dagmm()
Parameter containing:
tensor([0.1173])
Parameter containing:
tensor([ 1.1475, -0.3912, -0.4815,  0.8666,  1.2454, -0.2445,  1.4350])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor(0.4673)
Parameter containing:
tensor([0.5283, 0.2903, 0.5008, 0.5089, 0.3621, 0.4749, 0.4967])
Parameter containing:
tensor([[ 0.0439,  0.0255, -0.0151,  0.0544,  0.0647,  0.0090,  0.0726],
        [ 0.0255,  0.0419, -0.0004,  0.0360,  0.0314,  0.0235,  0.0499],
        [-0.0151, -0.0004,  0.0240,  0.0005, -0.0159,  0.0520, -0.0083],
        [ 0.0544,  0.0360,  0.0005,  0.0954,  0.0900,  0.0640,  0.1063],
        [ 0.0647,  0.0314, -0.0159,  0.0900,  0.1013,  0.0334,  0.1124],
        [ 0.0090,  0.0235,  0.0520,  0.0640,  0.0334,  0.1681,  0.0645],
        [ 0.0726,  0.0499, -0.0083,  0.1063,  0.1124,  0.0645,  0.1351]])
tensor([4.8554, 4.6761, 4.4661, 4.8710, 4.8982])
Parameter containing:
tensor([0.9724])
Parameter containing:
tensor([ 0.3688,  0.9626,  0.7468, -0.3848,  0.9530,  0.7081,  0.9869])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor([0.3145])
Parameter containing:
tensor([1.4138, 0.4490, 1.3144, 0.2030, 1.1138, 0.8565, 0.0193])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor(0.6101)
Parameter containing:
tensor([0.2734, 0.7025, 0.5866, 0.5380, 0.4088, 0.3543, 0.5669])
Parameter containing:
tensor([[ 0.0827, -0.0232, -0.0216, -0.0052, -0.0009,  0.0318, -0.0217],
        [-0.0232,  0.0179,  0.0211,  0.0308,  0.0019, -0.0294,  0.0174],
        [-0.0216,  0.0211,  0.0355,  0.0366, -0.0125, -0.0311,  0.0099],
        [-0.0052,  0.0308,  0.0366,  0.0781,  0.0157, -0.0543,  0.0373],
        [-0.0009,  0.0019, -0.0125,  0.0157,  0.0748,  0.0095,  0.0413],
        [ 0.0318, -0.0294, -0.0311, -0.0543,  0.0095,  0.0582, -0.0249],
        [-0.0217,  0.0174,  0.0099,  0.0373,  0.0413, -0.0249,  0.0391]])
Parameter containing:
tensor(0.7920)
Parameter containing:
tensor([0.3664, 0.7017, 0.5854, 0.6048, 0.4577, 0.3511, 0.5929])
Parameter containing:
tensor([[ 0.1143, -0.0355, -0.0402, -0.0143,  0.0052,  0.0471, -0.0279],
        [-0.0355,  0.0251,  0.0335,  0.0389, -0.0113, -0.0418,  0.0165],
        [-0.0402,  0.0335,  0.0528,  0.0540, -0.0270, -0.0540,  0.0137],
        [-0.0143,  0.0389,  0.0540,  0.0882, -0.0128, -0.0710,  0.0290],
        [ 0.0052, -0.0113, -0.0270, -0.0128,  0.0901,  0.0411,  0.0349],
        [ 0.0471, -0.0418, -0.0540, -0.0710,  0.0411,  0.0829, -0.0191],
        [-0.0279,  0.0165,  0.0137,  0.0290,  0.0349, -0.0191,  0.0330]])
tensor([-16.9612, -17.8355, -17.1198, -17.8214, -16.9765, -13.4832, -17.8883,
        -16.8427, -16.7266, -17.4553, -16.4203, -16.9409, -17.6146, -15.1495,
        -15.9379, -14.9479, -16.1854, -17.5789, -16.6280, -16.3914, -17.8705,
        -16.4540, -17.6894, -17.6952, -17.9105, -16.9977, -16.9534, -17.4336,
        -17.0536, -16.7684, -16.9417, -16.7795, -17.7103, -16.8721, -16.0556,
        -16.6950, -17.5597, -17.4694, -17.2260, -16.6612, -17.3744, -16.9852,
        -15.9520, -15.9058, -16.7894, -16.7476, -16.5294, -17.4851, -17.1710,
        -17.6680, -17.7744, -17.4803, -16.1885, -14.6914, -17.6060, -17.8666,
        -15.9840, -16.8587, -17.0594, -15.2725, -12.6417, -16.6181, -16.4920,
        -16.8071, -17.3286, -17.1299, -17.0695, -15.1893, -16.3951, -16.9207,
        -17.8430, -17.5642, -17.4121, -17.1715, -16.3110, -16.8747,  -9.4447,
        -17.1424, -16.8413, -15.5074, -16.9420, -16.5711, -16.9047, -17.0329,
        -16.1576, -12.1645, -15.3441, -17.2765, -13.0270, -17.2411, -16.8029,
        -17.0612, -16.9580, -16.4213, -16.3675, -17.3942, -16.8937, -16.7398,
        -16.9648, -14.1587, -17.4364, -15.2745, -16.3091, -16.2499, -16.8516,
        -17.1280, -13.5346, -17.6581, -13.9843, -17.1864, -16.6246, -16.5804,
        -16.8307, -15.3374, -17.8115, -17.0566, -16.4458, -15.4481, -17.3106,
        -16.7912, -17.0978, -16.9524, -16.8946, -16.9953, -16.8141, -16.8253,
        -17.8658, -15.8182, -17.7361, -17.1825, -17.2128, -17.8637, -15.8899,
        -13.8260, -17.1471, -16.3404, -17.5933, -16.9136, -17.3709, -17.0664,
        -16.5006, -12.6482, -16.4738, -16.5775, -16.7379, -17.5284, -16.8920,
        -17.3995, -17.0317, -17.3837, -17.0141, -15.2970, -16.1848, -16.9914,
        -17.8913, -17.4832, -17.8995, -16.2905, -15.0831, -16.9196, -17.1523,
        -16.7555, -17.1981, -15.8082, -16.2761, -17.2934, -16.5064, -16.0278,
        -17.3910, -17.2586, -15.3088, -16.0010, -17.4784, -17.1952, -16.1224,
        -17.5613, -16.2742, -17.6207, -16.7494, -16.2702, -17.0538, -16.6950,
        -17.7900, -17.1539, -15.8621, -16.9515, -16.4764, -16.7403, -15.7678,
        -15.0590, -17.3892, -16.1987, -14.6992, -16.0648, -17.3860, -17.0080,
        -17.8890, -16.5353, -15.8924, -16.6831, -17.8159, -17.6439, -16.6141,
        -17.5179, -11.3079, -14.3721, -16.9241, -11.8753, -12.8442, -17.4653,
        -17.2166, -16.8058, -17.2556, -15.4820, -17.3901, -16.7357, -16.1637,
        -12.8449, -14.0127, -17.8357, -17.0824, -16.0741, -16.9676, -17.7164,
        -16.6101, -15.9101, -16.6739, -16.7744, -15.0311, -16.1429, -17.6085,
        -17.2562, -17.5409, -16.1831, -15.0784, -15.8370, -17.1719, -16.1404,
        -17.5857, -17.6243, -17.8642, -16.6991, -16.5568, -16.3714, -16.9949,
        -17.1926, -16.9671, -16.8230, -15.7120, -16.2652, -17.3596, -16.8273,
        -16.4941, -14.3067, -15.1746, -16.7538, -16.6177, -14.5924, -17.4868,
        -16.9961, -15.2027, -16.1179, -17.8136, -16.9221, -17.4402, -17.1182,
        -17.0236, -16.9496, -16.7210, -17.2224, -16.8340, -14.6614, -17.7063,
        -17.6374, -17.6177, -17.2496, -16.2857, -16.8705, -16.0419, -16.8412,
        -17.6452, -13.3405, -17.0936, -17.2620, -14.8293, -16.4840, -17.7907,
        -17.6367, -16.9794, -17.4806, -16.8857, -17.2911, -16.7659, -15.7899,
        -16.0170, -14.3231, -16.3894, -16.0958, -13.7606, -15.9562, -16.4767,
        -17.6906, -17.0464, -15.9133, -16.8948, -16.1856, -15.6871, -17.4597,
        -15.1920, -17.2234, -14.9576, -17.7829, -17.8947, -16.1404, -17.2103,
        -16.7290, -15.5131, -17.3531, -15.0447, -17.8739, -15.7056, -16.7523,
        -15.4276, -16.6752, -16.6374, -17.0997, -17.2698, -16.1508, -16.9022,
        -15.6754, -17.2008, -17.5945, -16.4966, -17.2604, -15.9539, -17.1234,
        -17.6056, -17.3484, -16.1133, -17.3953, -16.8830, -16.6132, -17.1589,
        -17.1432, -17.6974, -17.3527, -16.8420, -16.8880,  -7.7879, -16.9901,
        -15.8571, -15.3417, -16.7466, -14.3681, -13.3931, -16.2099, -15.7932,
        -16.7962, -17.4833, -13.5962, -17.7920, -15.0972, -16.2645, -15.1858,
        -16.6797, -16.3931, -17.5987, -16.3806, -16.6691, -17.5670, -13.5660,
        -16.1329, -16.6498, -17.7633, -11.9445, -17.6930, -16.2073, -17.2636,
        -17.8689, -15.8726, -16.3609, -16.1388, -16.8463, -16.7943, -16.9481,
        -17.1803, -13.8762, -16.2069, -16.5714, -13.3702, -16.9707, -15.1451,
        -17.7622, -10.2545, -16.9429, -12.9870, -16.5277, -16.6480, -15.0934,
        -17.5117, -17.1645, -17.1222, -17.2802, -15.8978, -15.6626, -17.8973,
        -15.4456, -14.8156, -16.5574, -16.9696, -15.0446, -14.8778, -16.0970,
        -17.7404, -16.6000, -16.7870, -15.6458, -14.8563, -15.2952, -16.9022,
        -17.1012, -17.8283, -16.9903, -16.9610, -11.8376, -16.1911, -17.2052,
        -17.0333, -16.5970, -16.3592, -16.3686, -17.7152, -16.5726, -16.6921,
        -16.8786, -17.3966, -16.6916, -16.8164, -16.4075, -16.7753, -16.3588,
        -15.2915, -17.0675, -17.2979, -15.7603, -17.8791, -15.6601, -16.5525,
        -16.0544, -16.3048, -14.6642, -14.9261])
C:\Users\UOS\anaconda3\Lib\site-packages\torch\nn\modules\container.py:217: UserWarning:

Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.

Ref

  • https://openreview.net/forum?id=BJJLHbb0-